# -*- coding: utf-8 -*-
"""
Created on Thu Apr 13 09:34:52 2023

@author: lv
"""
import os
import torch
from torch.utils.data import Dataset
from typing import List, Tuple
import yaml

class QADataset(Dataset):
    """Question/answer dataset built from a set of directories of QA files.

    Every item is a ``(query, target)`` pair of 1-D tensors holding the token
    ids returned by ``tokenizer.encode``.

    Supported file formats, selected by file extension:

    * ``.txt``  -- blocks introduced by ``Q:`` / ``A:`` prefixes; unprefixed
      lines continue the current question or answer; blank lines are ignored.
    * ``.yml``  -- a YAML mapping with a ``conversations`` list whose items are
      ``[question, answer, ...]`` sequences (only the first two entries are used).
    * ``.conv`` -- ``E`` lines separate conversations and ``M <utterance>``
      lines hold turns; consecutive ``M`` lines are paired as question/answer.

    Args:
        data_paths: ``(directory, is_multi_dialogue)`` pairs. When
            ``is_multi_dialogue`` is true, every turn after the first one in a
            file is additionally emitted with all earlier turns of that file
            prepended as history.
        tokenizer: object providing ``encode(text) -> list of token ids`` and an
            ``end_token_id`` attribute.
        max_seq_length: maximum number of tokens kept for both query and target.
    """

    # Maps a data-file extension to the format name understood by readQAfromfile.
    _EXTENSION_FORMATS = {".txt": "txt", ".yml": "yml", ".conv": "conv"}

    def __init__(self, data_paths: List[Tuple[str, bool]], tokenizer, max_seq_length: int):
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.datas = []
        for path, is_multi_dialogue in data_paths:
            self.datas.extend(self.load_datas(path, is_multi_dialogue))

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.datas[index]

    def readHistory(self, qa_list: List[Tuple[str, str]], index: int) -> str:
        """Return the turns before ``index`` joined as history text.

        Each earlier turn contributes its question, a newline, its answer and
        another newline, in order.
        """
        return "".join(f"{question}\n{answer}\n" for question, answer in qa_list[:index])

    def load_datas(self, data_path: str, is_multi_dialogue: bool) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Load every supported QA file in ``data_path`` as (query, target) tensor pairs.

        Files with an unknown extension (and non-file entries) are skipped.
        When ``is_multi_dialogue`` is true, each turn after the first is also
        emitted with the file's earlier turns as history.
        """
        datas = []
        # Sorted so the sample order does not depend on os.listdir's arbitrary order.
        for filename in sorted(os.listdir(data_path)):
            file_path = os.path.join(data_path, filename)
            file_format = self._EXTENSION_FORMATS.get(os.path.splitext(filename)[1])
            if file_format is None or not os.path.isfile(file_path):
                continue
            qa_list = self.readQAfromfile(file_path, file_format)
            for i in range(len(qa_list)):
                # For i == 0 the history is empty, so the multi-turn sample would be
                # an exact duplicate of the single-turn one; skip it.
                if is_multi_dialogue and i > 0:
                    datas.append(self.makeItem(qa_list, i, True))
                datas.append(self.makeItem(qa_list, i, False))
        return datas

    def makeItem(self, qa_list: List[Tuple[str, str]], index: int, multi_dialogue=False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the (query, target) token tensors for turn ``index`` of ``qa_list``.

        The target is the stripped answer followed by two end tokens, truncated
        to ``max_seq_length``. The query is the question, optionally preceded by
        the earlier turns as history; it is truncated from the front so that
        the newest question (at the end) is preserved.
        """
        limit = max(self.max_seq_length, 0)
        end_id = self.tokenizer.end_token_id

        target_text = qa_list[index][1].strip()
        # Copy so the list returned by encode() is never mutated. A second end token
        # is appended on purpose (original note: it makes forced training take effect).
        target = list(self.tokenizer.encode(target_text)) + [end_id, end_id]
        target = target[:limit]

        history = self.readHistory(qa_list, index) if multi_dialogue else ""
        query_text = (history + qa_list[index][0]).strip()
        # Keep the tail: the latest question sits at the end of the query text.
        query_text = query_text[max(len(query_text) - limit, 0):]
        query = list(self.tokenizer.encode(query_text))
        # Character truncation does not bound the token count (a tokenizer may emit
        # several tokens per character), so also cap the tokens, again keeping the tail.
        query = query[max(len(query) - limit, 0):]

        return torch.tensor(query), torch.tensor(target)

    def readQAfromfile(self, file_path, file_format='txt'):
        """Parse ``file_path`` (UTF-8) into a list of ``(question, answer)`` tuples.

        Raises:
            ValueError: if ``file_format`` is not ``'txt'``, ``'yml'`` or ``'conv'``.
        """
        parsers = {
            'txt': self._parse_txt,
            'yml': self._parse_yml,
            'conv': self._parse_conv,
        }
        parser = parsers.get(file_format)
        if parser is None:
            raise ValueError("Unsupported file format: {}".format(file_format))
        with open(file_path, "r", encoding="utf-8") as f:
            return parser(f)

    @staticmethod
    def _parse_txt(f) -> List[Tuple[str, str]]:
        """Parse ``Q:``/``A:`` blocks from an open text file."""
        qa_list = []
        question = None
        answer = ""
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue  # blank lines carry no content
            if line.startswith("Q:"):
                if question is not None and answer:
                    qa_list.append((question, answer))
                question = line[2:]
                # Always reset so an answer seen before any question is not
                # attached to this new question.
                answer = ""
            elif line.startswith("A:"):
                answer += line[2:] + "\n"
            elif answer:
                answer += line + "\n"
            elif question is not None:
                # Continuation of a multi-line question; keep lines separated.
                question = question + "\n" + line if question else line
        if question is not None and answer:
            qa_list.append((question, answer))
        return qa_list

    @staticmethod
    def _parse_yml(f) -> List[Tuple[str, str]]:
        """Read ``(item[0], item[1])`` pairs from the YAML ``conversations`` list."""
        data = yaml.safe_load(f) or {}
        qa_list = []
        for item in data.get('conversations') or []:
            # Skip malformed entries that cannot provide a question and an answer.
            if not isinstance(item, (list, tuple)) or len(item) < 2:
                continue
            qa_list.append((item[0], item[1]))
        return qa_list

    @staticmethod
    def _parse_conv(f) -> List[Tuple[str, str]]:
        """Pair consecutive ``M`` lines within each ``E``-separated conversation."""
        qa_list = []
        question = None
        for raw_line in f:
            line = raw_line.rstrip("\r\n")
            if line == "E" or line.startswith("E "):
                # New conversation: drop any unanswered question.
                question = None
            elif line.startswith("M "):
                utterance = line[2:]
                if question is None:
                    question = utterance
                else:
                    qa_list.append((question, utterance))
                    question = None
        return qa_list
